import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split

# (mean, std) pairs handed to transforms.Normalize. Kept as plain tuples so that
# nothing at module level touches torchvision.
# NOTE(review): _GRAY_STATS is applied to both MNIST and FashionMNIST, and
# _RGB_STATS (these look like the common ImageNet values) to CIFAR10/100 and
# STL10. They are preserved as-is; confirm whether dataset-specific statistics
# were intended before changing them, since that would invalidate checkpoints.
_GRAY_STATS = ((0.1307,), (0.3081,))
_RGB_STATS = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
_SVHN_STATS = ([0.4376821, 0.4437697, 0.47280442],
               [0.19803012, 0.20101562, 0.19703614])


def _augment_settings(args):
    """Return ``(use_augment, magnitude)`` read from *args* with the shared defaults."""
    return getattr(args, 'use_augment', False), getattr(args, 'augment_magnitude', 4)


def _randaugment_ops(magnitude):
    """Return ``[RandAugment(...)]``, or ``[]`` if torchvision does not provide it."""
    try:
        from torchvision.transforms import RandAugment
    except ImportError:
        print("RandAugment not available, using basic augmentation")
        return []
    return [RandAugment(num_ops=2, magnitude=magnitude)]


def _open(dataset_cls, ops, download=True, **select):
    """Instantiate *dataset_cls* under ./data with *ops* composed as its transform."""
    return dataset_cls(root='./data', download=download,
                       transform=transforms.Compose(ops), **select)


def _make_loaders(train_set, test_set, batch_size, **loader_kwargs):
    """Wrap both datasets in DataLoaders; only the training loader is shuffled."""
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


def _build_gray28(args, label, dataset_cls):
    """Loaders for MNIST / FashionMNIST at their native 28x28 size, normalized."""
    use_augment, magnitude = _augment_settings(args)
    normalize = transforms.Normalize(*_GRAY_STATS)

    if use_augment:
        print(f"Using augmentation for {label} with magnitude {magnitude}")
        # RandAugment (when available) runs first, on the PIL image.
        train_ops = _randaugment_ops(magnitude) + [
            transforms.ToTensor(),
            transforms.RandomCrop(28, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
            normalize,
        ]
    else:
        print(f"No augmentation for {label}")
        train_ops = [transforms.ToTensor(), normalize]

    # Every training variant above yields 28x28 images, so the test images must
    # be 28x28 too. The previous test pipeline padded them to 32x32, which fed
    # the model a different input shape at evaluation time.
    test_ops = [transforms.ToTensor(), normalize]

    train_set = _open(dataset_cls, train_ops, train=True)
    test_set = _open(dataset_cls, test_ops, train=False)
    return _make_loaders(train_set, test_set, args.batchsize)


def _build_gray32(args, label, dataset_cls):
    """Loaders for the ``*_CNN`` variants: resized to 32x32, not normalized."""
    use_augment, magnitude = _augment_settings(args)

    if use_augment:
        print(f"Using augmentation for {label} with magnitude {magnitude}")
        train_ops = [transforms.Resize(32)] + _randaugment_ops(magnitude) + [
            transforms.ToTensor(),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
    else:
        print(f"No augmentation for {label}")
        train_ops = [transforms.Resize(32), transforms.ToTensor()]

    test_ops = [transforms.Resize(32), transforms.ToTensor()]

    train_set = _open(dataset_cls, train_ops, train=True)
    test_set = _open(dataset_cls, test_ops, train=False)
    return _make_loaders(train_set, test_set, args.batchsize)


def _build_stl10(args, label, dataset_cls):
    """Loaders for STL10 at 96x96, training on the 'train+unlabeled' split."""
    normalize = transforms.Normalize(*_RGB_STATS)

    # NOTE: this dataset reads 'data_augmentation' (default on), whereas the
    # other datasets read 'use_augment' (default off). Preserved for callers.
    if getattr(args, 'data_augmentation', True):
        print("Using Randomcrop, horizontalflip for STL10")
        train_ops = [
            transforms.RandomCrop(96, padding=8),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        train_ops = [transforms.ToTensor(), normalize]

    test_ops = [transforms.ToTensor(), normalize]

    train_set = _open(dataset_cls, train_ops, split='train+unlabeled')
    test_set = _open(dataset_cls, test_ops, split='test')
    return _make_loaders(train_set, test_set, args.batchsize)


def _build_stl10_cls(args, label, dataset_cls):
    """Loaders for STL10 resized to 32, using only the labelled 'train' split."""
    use_augment, magnitude = _augment_settings(args)

    if use_augment:
        print(f"Using augmentation for {label} with magnitude {magnitude}")
        train_ops = [transforms.Resize(32)] + _randaugment_ops(magnitude) + [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
        ]
    else:
        print(f"No augmentation for {label}")
        train_ops = [transforms.Resize(32), transforms.ToTensor()]

    test_ops = [transforms.Resize(32), transforms.ToTensor()]

    train_set = _open(dataset_cls, train_ops, split='train')
    test_set = _open(dataset_cls, test_ops, split='test')
    return _make_loaders(train_set, test_set, args.batchsize)


def _build_cifar(args, label, dataset_cls):
    """Loaders for CIFAR10 / CIFAR100 with normalization and pinned memory."""
    use_augment, magnitude = _augment_settings(args)
    normalize = transforms.Normalize(*_RGB_STATS)

    if use_augment:
        print(f"Using RandAugment for {label} with magnitude {magnitude}")
        train_ops = _randaugment_ops(magnitude) + [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4, padding_mode='edge'),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        print(f"No augmentation for {label}")
        train_ops = [transforms.ToTensor(), normalize]

    test_ops = [transforms.ToTensor(), normalize]

    train_set = _open(dataset_cls, train_ops, train=True)
    test_set = _open(dataset_cls, test_ops, train=False)
    return _make_loaders(train_set, test_set, args.batchsize, pin_memory=True)


def _build_svhn(args, label, dataset_cls):
    """Loaders for SVHN with normalization, label 10 remapped to 0, pinned memory."""
    use_augment, magnitude = _augment_settings(args)
    normalize = transforms.Normalize(*_SVHN_STATS)

    if use_augment:
        print(f"Using RandAugment for {label} with magnitude {magnitude}")
        # No horizontal flip here, unlike the CIFAR pipelines.
        train_ops = _randaugment_ops(magnitude) + [
            transforms.RandomCrop(32, 4, padding_mode='edge'),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        print(f"No augmentation for {label}")
        train_ops = [transforms.ToTensor(), normalize]

    test_ops = [transforms.ToTensor(), normalize]

    train_set = _open(dataset_cls, train_ops, split='train')
    test_set = _open(dataset_cls, test_ops, split='test')

    # Map label 10 to class 0 so that labels fall in 0..9.
    train_set.labels[train_set.labels == 10] = 0
    test_set.labels[test_set.labels == 10] = 0

    return _make_loaders(train_set, test_set, args.batchsize, pin_memory=True)


def _build_cifar10_spiking(args, label, dataset_cls):
    """Loaders for un-normalized CIFAR10 with worker processes and pinned memory."""
    train_ops = [
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(32, padding=4, padding_mode='edge'),
    ]

    if getattr(args, 'use_randaugment', False):
        # Not guarded against ImportError: the caller asked for it explicitly.
        from torchvision.transforms import RandAugment
        rand_augment = RandAugment(
            num_ops=getattr(args, 'randaugment_num_ops', 2),
            magnitude=getattr(args, 'randaugment_magnitude', 4),
        )
        # RandAugment must see the PIL image (or a uint8 tensor). It used to be
        # inserted after ToTensor(), i.e. on float tensors, where ops such as
        # Posterize/Equalize raise TypeError whenever they are sampled.
        train_ops.insert(0, rand_augment)

    test_ops = [transforms.ToTensor()]

    train_set = _open(dataset_cls, train_ops, train=True)
    # The training download above already fetched the archive.
    test_set = _open(dataset_cls, test_ops, download=False, train=False)

    num_workers = getattr(args, 'num_workers', 4)
    return _make_loaders(train_set, test_set, args.batchsize,
                         pin_memory=True, num_workers=num_workers)


def _build_scene_parse(args, label, dataset_cls):
    """Loaders over pre-saved image/annotation tensors in ./data/scene_parse_data."""
    data_dir = "./data/scene_parse_data"

    # NOTE: torch.load unpickles its input; only point this at trusted files.
    train_annotations = torch.load(f"{data_dir}/train_data_annotations.pt")
    train_images = torch.load(f"{data_dir}/train_data_images.pt")
    train_set = torch.utils.data.TensorDataset(train_images, train_annotations)

    test_annotations = torch.load(f"{data_dir}/test_data_annotations.pt")
    test_images = torch.load(f"{data_dir}/test_data_images.pt")
    test_set = torch.utils.data.TensorDataset(test_images, test_annotations)

    return _make_loaders(train_set, test_set, args.batchsize)


# args.dataset value -> (builder, name of the torchvision.datasets class or None).
# The class is looked up lazily inside get_dataset.
_DATASET_BUILDERS = {
    "MNIST": (_build_gray28, "MNIST"),
    "FashionMNIST": (_build_gray28, "FashionMNIST"),
    "FashionMNIST_CNN": (_build_gray32, "FashionMNIST"),
    "MNIST_CNN": (_build_gray32, "MNIST"),
    "STL10": (_build_stl10, "STL10"),
    "STL10_cls": (_build_stl10_cls, "STL10"),
    "CIFAR10": (_build_cifar, "CIFAR10"),
    "CIFAR100": (_build_cifar, "CIFAR100"),
    "SVHN": (_build_svhn, "SVHN"),
    "CIFAR10_spiking": (_build_cifar10_spiking, "CIFAR10"),
    "scene_parse_150": (_build_scene_parse, None),
}


def get_dataset(args):
    """Build the train and test DataLoaders for the dataset named by ``args.dataset``.

    Attributes read from *args*:
        dataset: one of the keys of ``_DATASET_BUILDERS``.
        batchsize: batch size used for both loaders.
        use_augment, augment_magnitude: optional (defaults ``False`` / ``4``);
            enable augmentation for every dataset except "STL10",
            "CIFAR10_spiking" and "scene_parse_150".
        data_augmentation: optional (default ``True``); "STL10" only.
        use_randaugment, randaugment_magnitude, randaugment_num_ops,
        num_workers: optional; "CIFAR10_spiking" only.

    Returns:
        ``(train_loader, test_loader)``. The training loader is shuffled and
        the test loader is not.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name. Previously the
            function returned ``None`` here, which made callers fail later
            while unpacking the result.
    """
    name = args.dataset
    entry = _DATASET_BUILDERS.get(name) if isinstance(name, str) else None
    if entry is None:
        supported = ", ".join(_DATASET_BUILDERS)
        raise ValueError(f"Unknown dataset {name!r}; supported datasets: {supported}")

    builder, cls_name = entry
    # Resolved only after validation, so an unknown name never touches torchvision.
    dataset_cls = getattr(datasets, cls_name) if cls_name else None
    return builder(args, name, dataset_cls)